{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Benchmark mixed precision training on CIFAR100\n",
    "\n",
    "In this notebook we will benchmark 1) the native PyTorch mixed precision module [`torch.cuda.amp`](https://pytorch.org/docs/master/amp.html) and 2) the Nvidia/Apex package.\n",
    "\n",
    "We will train a Wide-ResNet model on the CIFAR100 dataset using a Turing-enabled GPU and compare training times.\n",
    "\n",
    "**TL;DR**\n",
    "\n",
    "Ranked by mean time per epoch (fastest first):\n",
    "- 1st place: Nvidia/Apex \"O2\" (about 10.3 s per epoch)\n",
    "- 2nd place: `torch.cuda.amp`: autocast and scaler (about 11.0 s)\n",
    "- 3rd place: Nvidia/Apex \"O1\" (about 11.2 s)\n",
    "- 4th place: fp32 (about 17.8 s)\n",
    "\n",
    "According to @mcarilli: \"Native amp is more like a faster, better integrated, locally enabled O1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Installations and setup\n",
    "\n",
    "1) The [`torch.cuda.amp`](https://pytorch.org/docs/master/notes/amp_examples.html#working-with-multiple-models-losses-and-optimizers) module, which performs automatic mixed precision training without the Nvidia/Apex package, is available in PyTorch >=1.6.0. At the time of writing, this requires installing a nightly release."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install --pre --upgrade torch==1.6.0.dev20200411+cu101 torchvision -f https://download.pytorch.org/whl/nightly/cu101/torch_nightly.html\n",
    "# !pip install --pre --upgrade pytorch-ignite \n",
    "# !pip install --upgrade pynvml fire"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2) Let's install the Nvidia/Apex package:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install --upgrade --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" git+https://github.com/NVIDIA/apex/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('1.6.0.dev20200411+cu101', '0.6.0+cu101', '0.4.0.dev20200411')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import ignite\n",
    "torch.__version__, torchvision.__version__, ignite.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3) The scripts we will execute are located in `ignite/examples/contrib/cifar100_amp_benchmark` of the GitHub repository. Let's clone the repository and set up `PYTHONPATH` to execute the benchmark scripts:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into '/tmp/ignite'...\n",
      "remote: Enumerating objects: 5534, done.\n",
      "remote: Total 5534 (delta 0), reused 0 (delta 0), pack-reused 5534\n",
      "Receiving objects: 100% (5534/5534), 21.83 MiB | 14.43 MiB/s, done.\n",
      "Resolving deltas: 100% (3458/3458), done.\n"
     ]
    }
   ],
   "source": [
    "!git clone https://github.com/pytorch/ignite.git /tmp/ignite\n",
    "scriptspath=\"/tmp/ignite/examples/contrib/cifar100_amp_benchmark/\"\n",
    "setup=\"cd {} && export PYTHONPATH=$PWD:$PYTHONPATH\".format(scriptspath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4) Download the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /tmp/cifar100/cifar-100-python.tar.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d65be5d32a414180afa8ac6bafb01599",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /tmp/cifar100/cifar-100-python.tar.gz to /tmp/cifar100/\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset CIFAR100\n",
       "    Number of datapoints: 50000\n",
       "    Root location: /tmp/cifar100/\n",
       "    Split: Train"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torchvision.datasets.cifar import CIFAR100\n",
    "CIFAR100(root=\"/tmp/cifar100/\", train=True, download=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training in fp32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Epoch [1/20]: [195/195] 100%|████████████████████, batch loss=4.53 [00:16<00:00]\n",
      "Epoch [2/20]: [195/195] 100%|████████████████████, batch loss=4.25 [00:16<00:00]\n",
      "Epoch [3/20]: [195/195] 100%|████████████████████, batch loss=4.19 [00:16<00:00]\n",
      "Epoch [4/20]: [195/195] 100%|████████████████████, batch loss=3.94 [00:16<00:00]\n",
      "Epoch [5/20]: [195/195] 100%|████████████████████, batch loss=3.98 [00:16<00:00]\n",
      "Epoch [6/20]: [195/195] 100%|████████████████████, batch loss=3.91 [00:16<00:00]\n",
      "Epoch [7/20]: [195/195] 100%|█████████████████████, batch loss=3.8 [00:17<00:00]\n",
      "Epoch [8/20]: [195/195] 100%|████████████████████, batch loss=3.68 [00:17<00:00]\n",
      "Epoch [9/20]: [195/195] 100%|████████████████████, batch loss=3.52 [00:17<00:00]\n",
      "Epoch [10/20]: [195/195] 100%|███████████████████, batch loss=3.61 [00:17<00:00]\n",
      "Epoch [11/20]: [195/195] 100%|███████████████████, batch loss=3.63 [00:17<00:00]\n",
      "Epoch [12/20]: [195/195] 100%|███████████████████, batch loss=3.63 [00:17<00:00]\n",
      "Epoch [13/20]: [195/195] 100%|███████████████████, batch loss=3.39 [00:17<00:00]\n",
      "Epoch [14/20]: [195/195] 100%|███████████████████, batch loss=3.58 [00:17<00:00]\n",
      "Epoch [15/20]: [195/195] 100%|███████████████████, batch loss=3.45 [00:17<00:00]\n",
      "Epoch [16/20]: [195/195] 100%|███████████████████, batch loss=3.11 [00:17<00:00]\n",
      "Epoch [17/20]: [195/195] 100%|███████████████████, batch loss=3.28 [00:17<00:00]\n",
      "Epoch [18/20]: [195/195] 100%|███████████████████, batch loss=3.25 [00:17<00:00]\n",
      "Epoch [19/20]: [195/195] 100%|███████████████████, batch loss=3.18 [00:17<00:00]\n",
      "Epoch [20/20]: [195/195] 100%|███████████████████, batch loss=3.26 [00:17<00:00]\n",
      "- Mean elapsed time for 1 epoch: 17.8373235606472\n",
      "- Metrics:\n",
      "\tTrain Accuracy: 0.24\n",
      "\tTrain Loss: 3.07\n",
      "\tTest Accuracy: 0.24\n",
      "\tTest Loss: 3.13\n"
     ]
    }
   ],
   "source": [
    "!{setup} && python benchmark_fp32.py /tmp/cifar100/ --batch_size=256 --max_epochs=20"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training with `torch.cuda.amp`"
   ]
  },
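  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For orientation, a training step with `torch.cuda.amp` typically looks like the sketch below. This is not the code of `benchmark_torch_cuda_amp.py`: the model, optimizer, and batch are hypothetical placeholders, and we assume PyTorch >=1.6.0. Without a CUDA device, autocast and the scaler are simply disabled.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "\n",
    "use_amp = torch.cuda.is_available()  # assumption: fall back to plain fp32 on CPU\n",
    "device = \"cuda\" if use_amp else \"cpu\"\n",
    "\n",
    "# Hypothetical stand-ins for the benchmark's model and data\n",
    "model = nn.Linear(32, 100).to(device)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "scaler = GradScaler(enabled=use_amp)\n",
    "\n",
    "x = torch.randn(8, 32, device=device)\n",
    "y = torch.randint(0, 100, (8,), device=device)\n",
    "\n",
    "optimizer.zero_grad()\n",
    "# Forward pass and loss run under autocast: eligible ops execute in fp16\n",
    "with autocast(enabled=use_amp):\n",
    "    loss = criterion(model(x), y)\n",
    "# Scale the loss so that small fp16 gradients do not underflow\n",
    "scaler.scale(loss).backward()\n",
    "# Unscale the gradients and step; the step is skipped on inf/NaN gradients\n",
    "scaler.step(optimizer)\n",
    "# Adjust the scale factor for the next iteration\n",
    "scaler.update()\n",
    "```\n",
    "\n",
    "The model and optimizer stay in fp32; only the operations inside the `autocast` block are cast, which is why this mode is described as \"locally enabled\"."
   ]
  },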
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Epoch [1/20]: [195/195] 100%|████████████████████, batch loss=4.62 [00:10<00:00]\n",
      "Epoch [2/20]: [195/195] 100%|████████████████████, batch loss=4.22 [00:10<00:00]\n",
      "Epoch [3/20]: [195/195] 100%|████████████████████, batch loss=4.22 [00:10<00:00]\n",
      "Epoch [4/20]: [195/195] 100%|████████████████████, batch loss=3.96 [00:10<00:00]\n",
      "Epoch [5/20]: [195/195] 100%|████████████████████, batch loss=3.88 [00:10<00:00]\n",
      "Epoch [6/20]: [195/195] 100%|████████████████████, batch loss=3.93 [00:10<00:00]\n",
      "Epoch [7/20]: [195/195] 100%|████████████████████, batch loss=3.71 [00:10<00:00]\n",
      "Epoch [8/20]: [195/195] 100%|████████████████████, batch loss=3.73 [00:10<00:00]\n",
      "Epoch [9/20]: [195/195] 100%|████████████████████, batch loss=3.61 [00:10<00:00]\n",
      "Epoch [10/20]: [195/195] 100%|███████████████████, batch loss=3.52 [00:10<00:00]\n",
      "Epoch [11/20]: [195/195] 100%|███████████████████, batch loss=3.39 [00:10<00:00]\n",
      "Epoch [12/20]: [195/195] 100%|███████████████████, batch loss=3.35 [00:10<00:00]\n",
      "Epoch [13/20]: [195/195] 100%|███████████████████, batch loss=3.31 [00:10<00:00]\n",
      "Epoch [14/20]: [195/195] 100%|███████████████████, batch loss=3.43 [00:10<00:00]\n",
      "Epoch [15/20]: [195/195] 100%|███████████████████, batch loss=3.29 [00:10<00:00]\n",
      "Epoch [16/20]: [195/195] 100%|███████████████████, batch loss=3.46 [00:10<00:00]\n",
      "Epoch [17/20]: [195/195] 100%|███████████████████, batch loss=3.24 [00:10<00:00]\n",
      "Epoch [18/20]: [195/195] 100%|███████████████████, batch loss=3.24 [00:10<00:00]\n",
      "Epoch [19/20]: [195/195] 100%|███████████████████, batch loss=3.37 [00:10<00:00]\n",
      "Epoch [20/20]: [195/195] 100%|███████████████████, batch loss=3.07 [00:10<00:00]\n",
      "- Mean elapsed time for 1 epoch: 10.985005500703119\n",
      "- Metrics:\n",
      "\tTrain Accuracy: 0.25\n",
      "\tTrain Loss: 3.04\n",
      "\tTest Accuracy: 0.24\n",
      "\tTest Loss: 3.13\n"
     ]
    }
   ],
   "source": [
    "!{setup} && python benchmark_torch_cuda_amp.py /tmp/cifar100/ --batch_size=256 --max_epochs=20"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training with `Nvidia/Apex`\n",
    "\n",
    "We check two optimization levels, \"O1\" and \"O2\":\n",
    "- \"O1\": automatic casts around PyTorch functions and tensor methods\n",
    "- \"O2\": fp16 training with fp32 batchnorm and fp32 master weights"
   ]
  },
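  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough guide, selecting an optimization level with Apex usually comes down to one `amp.initialize` call plus a scaled backward pass. The sketch below is not the code of `benchmark_nvidia_apex.py`: the model, optimizer, and batch are hypothetical placeholders, and it assumes Apex is installed and a CUDA GPU is available.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from apex import amp  # requires the Nvidia/Apex install above and a CUDA GPU\n",
    "\n",
    "# Hypothetical stand-ins for the benchmark's model and data\n",
    "model = nn.Linear(32, 100).cuda()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Prepare model and optimizer for the chosen optimization level (\"O1\" or \"O2\")\n",
    "model, optimizer = amp.initialize(model, optimizer, opt_level=\"O1\")\n",
    "\n",
    "x = torch.randn(8, 32).cuda()\n",
    "y = torch.randint(0, 100, (8,)).cuda()\n",
    "\n",
    "optimizer.zero_grad()\n",
    "loss = criterion(model(x), y)\n",
    "# Backpropagate through the scaled loss; Apex manages the dynamic loss scale\n",
    "with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
    "    scaled_loss.backward()\n",
    "optimizer.step()\n",
    "```\n",
    "\n",
    "Switching between the two levels only changes the `opt_level` argument; the training step itself stays the same."
   ]
  },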
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.\n",
      "\n",
      "Defaults for this optimization level are:\n",
      "enabled                : True\n",
      "opt_level              : O1\n",
      "cast_model_type        : None\n",
      "patch_torch_functions  : True\n",
      "keep_batchnorm_fp32    : None\n",
      "master_weights         : None\n",
      "loss_scale             : dynamic\n",
      "Processing user overrides (additional kwargs that are not None)...\n",
      "After processing overrides, optimization options are:\n",
      "enabled                : True\n",
      "opt_level              : O1\n",
      "cast_model_type        : None\n",
      "patch_torch_functions  : True\n",
      "keep_batchnorm_fp32    : None\n",
      "master_weights         : None\n",
      "loss_scale             : dynamic\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0\n",
      "Epoch [1/20]: [1/195]   1%|                      , batch loss=5.03 [00:00<00:00]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0\n",
      "Epoch [1/20]: [1/195]   1%|                      , batch loss=5.12 [00:00<00:12]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0\n",
      "Epoch [1/20]: [3/195]   2%|▎                     , batch loss=4.95 [00:00<00:11]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0\n",
      "Epoch [1/20]: [195/195] 100%|████████████████████, batch loss=4.56 [00:10<00:00]\n",
      "Epoch [2/20]: [195/195] 100%|████████████████████, batch loss=4.18 [00:10<00:00]\n",
      "Epoch [3/20]: [195/195] 100%|████████████████████, batch loss=4.06 [00:10<00:00]\n",
      "Epoch [4/20]: [195/195] 100%|████████████████████, batch loss=4.03 [00:10<00:00]\n",
      "Epoch [5/20]: [195/195] 100%|████████████████████, batch loss=4.01 [00:10<00:00]\n",
      "Epoch [6/20]: [195/195] 100%|████████████████████, batch loss=4.02 [00:10<00:00]\n",
      "Epoch [7/20]: [195/195] 100%|████████████████████, batch loss=3.72 [00:10<00:00]\n",
      "Epoch [8/20]: [195/195] 100%|████████████████████, batch loss=3.62 [00:10<00:00]\n",
      "Epoch [9/20]: [195/195] 100%|████████████████████, batch loss=3.72 [00:10<00:00]\n",
      "Epoch [10/20]: [195/195] 100%|███████████████████, batch loss=3.57 [00:10<00:00]\n",
      "Epoch [11/20]: [195/195] 100%|████████████████████, batch loss=3.5 [00:10<00:00]\n",
      "Epoch [12/20]: [195/195] 100%|████████████████████, batch loss=3.6 [00:10<00:00]\n",
      "Epoch [13/20]: [195/195] 100%|███████████████████, batch loss=3.28 [00:10<00:00]\n",
      "Epoch [14/20]: [195/195] 100%|███████████████████, batch loss=3.43 [00:10<00:00]\n",
      "Epoch [15/20]: [195/195] 100%|███████████████████, batch loss=3.45 [00:10<00:00]\n",
      "Epoch [16/20]: [195/195] 100%|███████████████████, batch loss=3.13 [00:10<00:00]\n",
      "Epoch [17/20]: [195/195] 100%|███████████████████, batch loss=3.26 [00:10<00:00]\n",
      "Epoch [18/20]: [195/195] 100%|████████████████████, batch loss=3.1 [00:10<00:00]\n",
      "Epoch [19/20]: [195/195] 100%|███████████████████, batch loss=3.19 [00:10<00:00]\n",
      "Epoch [20/20]: [195/195] 100%|███████████████████, batch loss=3.18 [00:10<00:00]\n",
      "- Mean elapsed time for 1 epoch: 11.186348466202617\n",
      "- Metrics:\n",
      "\tTrain Accuracy: 0.24\n",
      "\tTrain Loss: 3.07\n",
      "\tTest Accuracy: 0.24\n",
      "\tTest Loss: 3.11\n"
     ]
    }
   ],
   "source": [
    "!{setup} && python benchmark_nvidia_apex.py /tmp/cifar100/ --batch_size=256 --max_epochs=20 --opt=\"O1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.\n",
      "\n",
      "Defaults for this optimization level are:\n",
      "enabled                : True\n",
      "opt_level              : O2\n",
      "cast_model_type        : torch.float16\n",
      "patch_torch_functions  : False\n",
      "keep_batchnorm_fp32    : True\n",
      "master_weights         : True\n",
      "loss_scale             : dynamic\n",
      "Processing user overrides (additional kwargs that are not None)...\n",
      "After processing overrides, optimization options are:\n",
      "enabled                : True\n",
      "opt_level              : O2\n",
      "cast_model_type        : torch.float16\n",
      "patch_torch_functions  : False\n",
      "keep_batchnorm_fp32    : True\n",
      "master_weights         : True\n",
      "loss_scale             : dynamic\n",
      "Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0\n",
      "Epoch [1/20]: [1/195]   1%|                      , batch loss=5.01 [00:00<00:00]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0\n",
      "Epoch [1/20]: [1/195]   1%|                      , batch loss=4.98 [00:00<00:11]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0\n",
      "Epoch [1/20]: [3/195]   2%|▎                     , batch loss=4.97 [00:00<00:10]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0\n",
      "Epoch [1/20]: [195/195] 100%|████████████████████, batch loss=4.51 [00:09<00:00]\n",
      "Epoch [2/20]: [195/195] 100%|████████████████████, batch loss=4.31 [00:09<00:00]\n",
      "Epoch [3/20]: [195/195] 100%|████████████████████, batch loss=4.01 [00:09<00:00]\n",
      "Epoch [4/20]: [195/195] 100%|████████████████████, batch loss=3.98 [00:09<00:00]\n",
      "Epoch [5/20]: [195/195] 100%|████████████████████, batch loss=3.86 [00:09<00:00]\n",
      "Epoch [6/20]: [195/195] 100%|█████████████████████, batch loss=3.7 [00:09<00:00]\n",
      "Epoch [7/20]: [195/195] 100%|████████████████████, batch loss=3.62 [00:09<00:00]\n",
      "Epoch [8/20]: [195/195] 100%|█████████████████████, batch loss=3.6 [00:09<00:00]\n",
      "Epoch [9/20]: [195/195] 100%|████████████████████, batch loss=3.78 [00:09<00:00]\n",
      "Epoch [10/20]: [195/195] 100%|███████████████████, batch loss=3.45 [00:09<00:00]\n",
      "Epoch [11/20]: [195/195] 100%|████████████████████, batch loss=3.5 [00:09<00:00]\n",
      "Epoch [12/20]: [195/195] 100%|███████████████████, batch loss=3.52 [00:09<00:00]\n",
      "Epoch [13/20]: [195/195] 100%|███████████████████, batch loss=3.46 [00:09<00:00]\n",
      "Epoch [14/20]: [195/195] 100%|███████████████████, batch loss=3.51 [00:09<00:00]\n",
      "Epoch [15/20]: [195/195] 100%|███████████████████, batch loss=3.44 [00:09<00:00]\n",
      "Epoch [16/20]: [195/195] 100%|███████████████████, batch loss=3.29 [00:09<00:00]\n",
      "Epoch [17/20]: [195/195] 100%|███████████████████, batch loss=3.29 [00:09<00:00]\n",
      "Epoch [18/20]: [195/195] 100%|███████████████████, batch loss=3.33 [00:09<00:00]\n",
      "Epoch [19/20]: [195/195] 100%|███████████████████, batch loss=3.14 [00:09<00:00]\n",
      "Epoch [20/20]: [195/195] 100%|███████████████████, batch loss=3.25 [00:09<00:00]\n",
      "- Mean elapsed time for 1 epoch: 10.30258504075464\n",
      "- Metrics:\n",
      "\tTrain Accuracy: 0.25\n",
      "\tTrain Loss: 3.06\n",
      "\tTest Accuracy: 0.25\n",
      "\tTest Loss: 3.11\n"
     ]
    }
   ],
   "source": [
    "!{setup} && python benchmark_nvidia_apex.py /tmp/cifar100/ --batch_size=256 --max_epochs=20 --opt=\"O2\""
   ]
  },
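  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put the numbers side by side, the mean epoch times reported above can be turned into speedups over fp32. This is a small sketch: the dictionary and variable names are ours, the values are copied from the outputs above, and they will differ on other hardware.\n",
    "\n",
    "```python\n",
    "# Mean time per epoch in seconds, copied from the benchmark outputs above\n",
    "mean_epoch_time = {\n",
    "    \"fp32\": 17.8373235606472,\n",
    "    \"torch.cuda.amp\": 10.985005500703119,\n",
    "    \"Nvidia/Apex O1\": 11.186348466202617,\n",
    "    \"Nvidia/Apex O2\": 10.30258504075464,\n",
    "}\n",
    "\n",
    "baseline = mean_epoch_time[\"fp32\"]\n",
    "# Sort from fastest to slowest and report the speedup relative to fp32\n",
    "for name, seconds in sorted(mean_epoch_time.items(), key=lambda kv: kv[1]):\n",
    "    print(f\"{name:<16} {seconds:6.2f} s/epoch  {baseline / seconds:.2f}x vs fp32\")\n",
    "```\n",
    "\n",
    "All three mixed precision runs reach train and test accuracy within 0.01 of the fp32 run after 20 epochs, so the comparison here is mainly about speed."
   ]
  },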
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
